/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.activemq.transport.amqp.client.sasl;

import javax.security.sasl.SaslException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.apache.qpid.proton.engine.Sasl;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Manages the client side of the SASL authentication process: selects a mechanism from those
 * offered by the remote peer, sends the initial response and answers any challenges.
 */
public class SaslAuthenticator {

   private static final Logger LOG = LoggerFactory.getLogger(SaslAuthenticator.class);

   private final Sasl sasl;
   private final String username;
   private final String password;
   private final String authzid;
   private final String mechanismRestriction;
   // Chosen during SASL init; stays null until a matching mechanism has been selected.
   private Mechanism mechanism;

   /**
    * Create the authenticator and initialize it.
    *
    * @param sasl                 The Proton SASL entry point this class will use to manage the authentication.
    * @param username             The user name that will be used to authenticate.
    * @param password             The password that will be used to authenticate.
    * @param authzid              The authzid used when authenticating (currently only with PLAIN)
    * @param mechanismRestriction A particular mechanism to use (if offered by the server) or null to allow selection.
    *                             The comparison against the offered mechanism names ignores case.
    */
   public SaslAuthenticator(Sasl sasl, String username, String password, String authzid, String mechanismRestriction) {
      this.sasl = sasl;
      this.username = username;
      this.password = password;
      this.authzid = authzid;
      this.mechanismRestriction = mechanismRestriction;
   }

   /**
    * Process the SASL authentication cycle until such time as an outcome is determined. This
    * method must be called by the managing entity until the return value is true indicating a
    * successful authentication or a SecurityException is thrown indicating that the
    * handshake failed.
    *
    * @return true once the SASL exchange is in the PASS state, false while it is still in progress.
    * @throws SecurityException if no matching mechanism is offered, a mechanism raises a
    *                           SaslException while computing a response, a challenge arrives
    *                           before a mechanism was selected, or the exchange reaches the FAIL state.
    */
   public boolean authenticate() throws SecurityException {
      switch (sasl.getState()) {
         case PN_SASL_IDLE:
            handleSaslInit();
            break;
         case PN_SASL_STEP:
            handleSaslStep();
            break;
         case PN_SASL_FAIL:
            handleSaslFail();
            break;
         case PN_SASL_PASS:
            return true;
         default:
      }

      return false;
   }

   /**
    * Selects a mechanism from the remote peer's offered list (if the list has arrived) and
    * sends its initial response, when that response is non-empty.
    */
   private void handleSaslInit() throws SecurityException {
      try {
         String[] remoteMechanisms = sasl.getRemoteMechanisms();
         // Nothing to do until the remote peer's mechanism list has been received.
         if (remoteMechanisms != null && remoteMechanisms.length != 0) {
            mechanism = findMatchingMechanism(remoteMechanisms);
            if (mechanism != null) {
               mechanism.setUsername(username);
               mechanism.setPassword(password);
               mechanism.setAuthzid(authzid);
               // TODO - set additional options from URI.
               // TODO - set a host value.

               sasl.setMechanisms(mechanism.getName());
               byte[] response = mechanism.getInitialResponse();
               if (response != null && response.length != 0) {
                  sasl.send(response, 0, response.length);
               }
            } else {
               // TODO - Better error message.
               throw new SecurityException("Could not find a matching SASL mechanism for the remote peer.");
            }
         }
      } catch (SaslException se) {
         // TODO - Better error message.
         throw new SecurityException("Exception while processing SASL init.", se);
      }
   }

   /**
    * Picks the highest-priority supported mechanism among those offered by the remote peer.
    *
    * @param remoteMechanisms the mechanism names offered by the remote peer.
    * @return the best applicable mechanism, or null when none is supported and applicable.
    */
   private Mechanism findMatchingMechanism(String... remoteMechanisms) {

      Mechanism match = null;
      List<Mechanism> found = new ArrayList<>();

      for (String remoteMechanism : remoteMechanisms) {
         // Mechanism names are matched case-insensitively below, so apply the restriction the same way.
         if (mechanismRestriction != null && !mechanismRestriction.equalsIgnoreCase(remoteMechanism)) {
            LOG.debug("Skipping {} mechanism because it is not the configured mechanism restriction {}", remoteMechanism, mechanismRestriction);
            continue;
         }

         Mechanism candidate;
         if (remoteMechanism.equalsIgnoreCase("PLAIN")) {
            candidate = new PlainMechanism();
         } else if (remoteMechanism.equalsIgnoreCase("ANONYMOUS")) {
            candidate = new AnonymousMechanism();
         } else if (remoteMechanism.equalsIgnoreCase("CRAM-MD5")) {
            candidate = new CramMD5Mechanism();
         } else {
            LOG.debug("Unknown remote mechanism {}, skipping", remoteMechanism);
            continue;
         }

         if (candidate.isApplicable(username, password)) {
            found.add(candidate);
         }
      }

      if (!found.isEmpty()) {
         // Sorts by priority using Mechanism comparison and return the last value in
         // list which is the Mechanism deemed to be the highest priority match.
         Collections.sort(found);
         match = found.get(found.size() - 1);
      }

      LOG.info("Best match for SASL auth was: {}", match);

      return match;
   }

   /**
    * Reads any pending challenge from the remote peer and sends the selected mechanism's
    * response to it.
    */
   private void handleSaslStep() throws SecurityException {
      try {
         int pending = sasl.pending();
         if (pending != 0) {
            if (mechanism == null) {
               // Report a protocol problem instead of failing with a NullPointerException.
               throw new SecurityException("Received a SASL challenge before a mechanism was selected.");
            }
            byte[] challenge = new byte[pending];
            sasl.recv(challenge, 0, challenge.length);
            byte[] response = mechanism.getChallengeResponse(challenge);
            // Guard against a mechanism returning null instead of an (empty) response array.
            if (response != null) {
               sasl.send(response, 0, response.length);
            }
         }
      } catch (SaslException se) {
         // TODO - Better error message.
         throw new SecurityException("Exception while processing SASL step.", se);
      }
   }

   /**
    * Signals that the SASL exchange ended in the FAIL state.
    */
   private void handleSaslFail() throws SecurityException {
      // TODO - Better error message.
      throw new SecurityException("Client failed to authenticate");
   }
}
